import re
from pprint import pprint
import numpy as np

'''
Analyze the test results recorded in a log file: extract the test result of
every epoch and sort the epochs by Dice score.
'''

# Alternative input log, kept for quick switching:
# log_path="/home/liukai/projects/SIFA/log/train4label.log"
log_path = "/home/liukai/projects/SIFA/log/train4labelWithAugment.log"

# One report chunk: everything from "epoch:" up to the next run of 11 dashes.
_EPOCH_BLOCK_RE = re.compile(r'epoch:.*?-----------')
# The epoch number is the run of digits directly after "epoch:".
_EPOCH_NUMBER_RE = re.compile(r'epoch:([0-9]+)')
# Shortest span starting at "Dice" and ending at the first "]".
_DICE_RE = re.compile(r'Dice.*?]')
# Characters of a Dice match that are cut off before the first number.
# NOTE(review): presumably the length of a prefix such as "Dice mean:[" --
# confirm against the actual log format.
_DICE_PREFIX_LEN = 11


def parse_dice_by_epoch(text):
    """Extract the per-epoch Dice values from the raw text of a log.

    Newlines are removed first, so a report may span several lines. Every
    chunk matching ``epoch:...-----------`` is inspected; its first
    ``epoch:<digits>`` gives the epoch number and its first ``Dice...]``
    span gives the space-separated values.

    The first value of every epoch is dropped.
    NOTE(review): presumably the background class -- confirm.

    Chunks without an epoch number, without a Dice span, or with no value
    left after dropping the first one are skipped instead of aborting the
    whole analysis. If an epoch appears more than once, the last report wins.

    Args:
        text: Full content of the log file.

    Returns:
        dict mapping epoch number (int) to a 1-D ``numpy`` array of floats.

    Raises:
        ValueError: If a token inside a Dice span is not a valid float.
    """
    flattened = text.replace('\n', '')
    dice_by_epoch = {}
    for chunk in _EPOCH_BLOCK_RE.findall(flattened):
        epoch_match = _EPOCH_NUMBER_RE.search(chunk)
        dice_match = _DICE_RE.search(chunk)
        if epoch_match is None or dice_match is None:
            # e.g. an "epoch:" line that is not followed by a test report.
            continue

        # Strip the fixed-length prefix and the closing "]".
        numbers_text = dice_match.group(0)[_DICE_PREFIX_LEN:-1]
        # Consecutive spaces produce empty tokens; ignore them.
        values = [float(tok) for tok in numbers_text.split(' ') if tok]

        kept = values[1:]
        if not kept:
            # An empty array would have a NaN mean and break the ranking.
            continue
        dice_by_epoch[int(epoch_match.group(1))] = np.array(kept)
    return dice_by_epoch


def rank_by_mean_dice(dice_by_epoch):
    """Return ``(epoch, dice_array)`` pairs ordered by mean Dice, best first.

    Args:
        dice_by_epoch: Mapping as returned by ``parse_dice_by_epoch``.

    Returns:
        list of ``(epoch, dice_array)`` tuples, highest mean first.
    """
    return sorted(
        dice_by_epoch.items(),
        key=lambda item: item[1].mean(),
        reverse=True,
    )


def main():
    """Read ``log_path`` and print every epoch's mean and per-value Dice."""
    with open(log_path, 'r') as f:
        context = f.read()

    for epoch, dice in rank_by_mean_dice(parse_dice_by_epoch(context)):
        print(epoch, dice.mean(), dice)


if __name__ == "__main__":
    main()